/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.sdk.io.gcp.bigquery;



import com.google.api.services.bigquery.model.TableRow;

import com.bff.gaia.unified.sdk.coders.*;

import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.coders.CoderException;

import com.bff.gaia.unified.sdk.coders.StringUtf8Coder;

import com.bff.gaia.unified.sdk.coders.StructuredCoder;

import com.bff.gaia.unified.sdk.coders.VarLongCoder;

import com.bff.gaia.unified.sdk.io.gcp.bigquery.WriteBundlesToFiles.Result;

import com.bff.gaia.unified.sdk.transforms.DoFn;

import com.bff.gaia.unified.sdk.transforms.SerializableFunction;

import com.bff.gaia.unified.sdk.transforms.windowing.BoundedWindow;

import com.bff.gaia.unified.sdk.values.KV;

import com.bff.gaia.unified.sdk.values.PCollectionView;

import com.bff.gaia.unified.sdk.values.ShardedKey;

import com.bff.gaia.unified.sdk.values.TupleTag;

import com.bff.gaia.unified.vendor.guava.com.google.common.collect.Lists;

import com.bff.gaia.unified.vendor.guava.com.google.common.collect.Maps;



import java.io.IOException;

import java.io.InputStream;

import java.io.OutputStream;

import java.io.Serializable;

import java.util.Collections;

import java.util.List;

import java.util.Map;

import java.util.Objects;

import java.util.concurrent.ThreadLocalRandom;



import static com.bff.gaia.unified.vendor.guava.com.google.common.base.Preconditions.checkNotNull;



/**

 * Writes each bundle of {@link TableRow} elements out to separate file using {@link

 * TableRowWriter}. Elements destined to different destinations are written to separate files. The

 * transform will not write an element to a file if it is already writing to {@link

 * #maxNumWritersPerBundle} files and the element is destined to a new destination. In this case,

 * the element will be spilled into the output, and the {@link WriteGroupedRecordsToFiles} transform

 * will take care of writing it to a file.

 */

class WriteBundlesToFiles<DestinationT, ElementT>

    extends DoFn<KV<DestinationT, ElementT>, Result<DestinationT>> {



  // When we spill records, shard the output keys to prevent hotspots. Experiments running up to

  // 10TB of data have shown a sharding of 10 to be a good choice.

  private static final int SPILLED_RECORD_SHARDING_FACTOR = 10;



  // Map from tablespec to a writer for that table.

  private transient Map<DestinationT, TableRowWriter> writers;

  private transient Map<DestinationT, BoundedWindow> writerWindows;

  private final PCollectionView<String> tempFilePrefixView;

  private final TupleTag<KV<ShardedKey<DestinationT>, ElementT>> unwrittenRecordsTag;

  private final int maxNumWritersPerBundle;

  private final long maxFileSize;

  private final SerializableFunction<ElementT, TableRow> toRowFunction;

  private int spilledShardNumber;



  /**

   * The result of the {@link WriteBundlesToFiles} transform. Corresponds to a single output file,

   * and encapsulates the table it is destined to as well as the file byte size.

   */

  static final class Result<DestinationT> implements Serializable {

    private static final long serialVersionUID = 1L;

    public final String filename;

    public final Long fileByteSize;

    public final DestinationT destination;



    public Result(String filename, Long fileByteSize, DestinationT destination) {

      checkNotNull(destination);

      this.filename = filename;

      this.fileByteSize = fileByteSize;

      this.destination = destination;

    }



    @Override

    public boolean equals(Object other) {

      if (other instanceof Result) {

        Result<DestinationT> o = (Result<DestinationT>) other;

        return Objects.equals(this.filename, o.filename)

            && Objects.equals(this.fileByteSize, o.fileByteSize)

            && Objects.equals(this.destination, o.destination);

      }

      return false;

    }



    @Override

    public int hashCode() {

      return Objects.hash(filename, fileByteSize, destination);

    }



    @Override

    public String toString() {

      return "Result{"

          + "filename='"

          + filename

          + '\''

          + ", fileByteSize="

          + fileByteSize

          + ", destination="

          + destination

          + '}';

    }

  }



  /** A coder for the {@link Result} class. */
  public static class ResultCoder<DestinationT> extends StructuredCoder<Result<DestinationT>> {
    private static final StringUtf8Coder FILENAME_CODER = StringUtf8Coder.of();
    private static final VarLongCoder BYTE_SIZE_CODER = VarLongCoder.of();

    /** Returns a coder whose destination encoding is delegated to {@code destinationCoder}. */
    public static <DestinationT> ResultCoder<DestinationT> of(
        Coder<DestinationT> destinationCoder) {
      return new ResultCoder<>(destinationCoder);
    }

    private final Coder<DestinationT> destinationCoder;

    ResultCoder(Coder<DestinationT> destinationCoder) {
      this.destinationCoder = destinationCoder;
    }

    @Override
    public void encode(Result<DestinationT> value, OutputStream outStream) throws IOException {
      if (value == null) {
        throw new CoderException("cannot encode a null value");
      }
      // Wire format: filename, then file byte size, then destination.
      FILENAME_CODER.encode(value.filename, outStream);
      BYTE_SIZE_CODER.encode(value.fileByteSize, outStream);
      destinationCoder.encode(value.destination, outStream);
    }

    @Override
    public Result<DestinationT> decode(InputStream inStream) throws IOException {
      // Must read fields in exactly the order encode() wrote them.
      String filename = FILENAME_CODER.decode(inStream);
      Long fileByteSize = BYTE_SIZE_CODER.decode(inStream);
      return new Result<>(filename, fileByteSize, destinationCoder.decode(inStream));
    }

    @Override
    public List<? extends Coder<?>> getCoderArguments() {
      return Collections.singletonList(destinationCoder);
    }

    @Override
    public void verifyDeterministic() {}
  }



  /**
   * Creates the transform.
   *
   * @param tempFilePrefixView side input providing the prefix used for temporary file names
   * @param unwrittenRecordsTag tag under which spilled (not-yet-written) records are emitted
   * @param maxNumWritersPerBundle maximum number of writers kept open simultaneously in a bundle
   * @param maxFileSize maximum size in bytes of one output file before a new file is started
   * @param toRowFunction converts an input element into the {@link TableRow} that gets written
   */
  WriteBundlesToFiles(
      PCollectionView<String> tempFilePrefixView,
      TupleTag<KV<ShardedKey<DestinationT>, ElementT>> unwrittenRecordsTag,
      int maxNumWritersPerBundle,
      long maxFileSize,
      SerializableFunction<ElementT, TableRow> toRowFunction) {
    this.tempFilePrefixView = tempFilePrefixView;
    this.unwrittenRecordsTag = unwrittenRecordsTag;
    this.maxNumWritersPerBundle = maxNumWritersPerBundle;
    this.maxFileSize = maxFileSize;
    this.toRowFunction = toRowFunction;
  }



  @StartBundle
  public void startBundle() {
    // Reset all per-bundle state: the runner may reuse a single DoFn instance across bundles.
    // Starting the spill shard counter at a random offset spreads spilled keys across shards
    // instead of every bundle beginning at shard 0.
    spilledShardNumber = ThreadLocalRandom.current().nextInt(SPILLED_RECORD_SHARDING_FACTOR);
    writers = Maps.newHashMap();
    writerWindows = Maps.newHashMap();
  }



  /**
   * Opens a fresh {@link TableRowWriter} for {@code destination}, records it and the window it
   * belongs to in the per-bundle maps, and returns it.
   */
  TableRowWriter createAndInsertWriter(
      DestinationT destination, String tempFilePrefix, BoundedWindow window) throws Exception {
    TableRowWriter newWriter = new TableRowWriter(tempFilePrefix);
    writerWindows.put(destination, window);
    writers.put(destination, newWriter);
    return newWriter;
  }



  /**
   * Writes the element through the writer for its destination, creating the writer on first use.
   * If {@link #maxNumWritersPerBundle} writers are already open and the element targets a new
   * destination, the record is instead spilled to {@link #unwrittenRecordsTag} under a sharded
   * key. A writer whose file exceeds {@link #maxFileSize} is closed, its {@link Result} emitted,
   * and a fresh file started before the write.
   */
  @ProcessElement
  public void processElement(
      ProcessContext c, @Element KV<DestinationT, ElementT> element, BoundedWindow window)
      throws Exception {
    String tempFilePrefix = c.sideInput(tempFilePrefixView);
    DestinationT destination = element.getKey();

    TableRowWriter writer;
    if (writers.containsKey(destination)) {
      writer = writers.get(destination);
    } else if (writers.size() < maxNumWritersPerBundle) {
      // Only create a new writer if we have fewer than maxNumWritersPerBundle already in this
      // bundle. (Was "<=", which allowed one writer more than the documented maximum.)
      writer = createAndInsertWriter(destination, tempFilePrefix, window);
    } else {
      // We already have too many writers open in this bundle. "Spill" this record into the
      // output; it will be grouped and written to a file in a subsequent stage. Sharding the
      // key prevents hotspots in that grouping.
      c.output(
          unwrittenRecordsTag,
          KV.of(
              ShardedKey.of(destination, (++spilledShardNumber) % SPILLED_RECORD_SHARDING_FACTOR),
              element.getValue()));
      return;
    }

    if (writer.getByteSize() > maxFileSize) {
      // File is too big. Close it, emit its Result, and open a new file for this destination.
      writer.close();
      TableRowWriter.Result result = writer.getResult();
      c.output(new Result<>(result.resourceId.toString(), result.byteSize, destination));
      writer = createAndInsertWriter(destination, tempFilePrefix, window);
    }

    try {
      writer.write(toRowFunction.apply(element.getValue()));
    } catch (Exception e) {
      // Discard the write result and close the writer, then rethrow the original failure.
      try {
        writer.close();
        // The writer does not need to be reset, as this DoFn cannot be reused.
      } catch (Exception closeException) {
        // Do not mask the exception that caused the write to fail.
        e.addSuppressed(closeException);
      }
      throw e;
    }
  }



  /**
   * Closes every writer opened during the bundle and outputs one {@link Result} per destination,
   * timestamped/windowed with the window the writer was created in.
   *
   * @throws IOException (carrying suppressed causes) if any writer fails to close, or if any
   *     result fails to be emitted
   */
  @FinishBundle
  public void finishBundle(FinishBundleContext c) throws Exception {
    List<Exception> exceptionList = Lists.newArrayList();
    for (TableRowWriter writer : writers.values()) {
      try {
        writer.close();
      } catch (Exception e) {
        exceptionList.add(e);
      }
    }
    if (!exceptionList.isEmpty()) {
      Exception e = new IOException("Failed to close some writers");
      for (Exception thrown : exceptionList) {
        e.addSuppressed(thrown);
      }
      throw e;
    }

    for (Map.Entry<DestinationT, TableRowWriter> entry : writers.entrySet()) {
      try {
        DestinationT destination = entry.getKey();
        TableRowWriter writer = entry.getValue();
        TableRowWriter.Result result = writer.getResult();
        // Window recorded when the writer was created (see createAndInsertWriter).
        BoundedWindow window = writerWindows.get(destination);
        c.output(
            new Result<>(result.resourceId.toString(), result.byteSize, destination),
            window.maxTimestamp(),
            window);
      } catch (Exception e) {
        exceptionList.add(e);
      }
    }
    writers.clear();

    // BUG FIX: failures while emitting results used to be collected here and then silently
    // dropped. Surface them so the bundle is retried instead of losing file results.
    if (!exceptionList.isEmpty()) {
      Exception e = new IOException("Failed to output results for some writers");
      for (Exception thrown : exceptionList) {
        e.addSuppressed(thrown);
      }
      throw e;
    }
  }

}